{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d0b72877",
   "metadata": {},
   "source": [
    "# Pre-encoding a dataset for DALLE·mini"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba7b31e6",
   "metadata": {},
   "source": [
    "This notebook shows how to pre-encode images into token sequences using JAX and VQGAN, reading from a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
    "\n",
    "Adapt it to your own dataset and image encoder.\n",
    "\n",
    "At the end you should have a dataset of pairs, each made of:\n",
    "* a caption defined as a string\n",
    "* an encoded image defined as a list of ints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b59489e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import torchvision.transforms as T\n",
    "\n",
    "import webdataset as wds\n",
    "\n",
    "import jax\n",
    "import braceexpand\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7c4c1e6",
   "metadata": {},
   "source": [
    "## Configuration Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1265dbfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "shards = \"my_images/shard-{0000..0008}.tar\"  # input shards, written in the braceexpand notation used by webdataset\n",
    "encoded_output = Path(\"encoded_data\")  # directory where the encoded data will be saved\n",
    "\n",
    "# VQGAN checkpoint: repository name and the commit id it is pinned to\n",
    "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
    "    \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
    "    \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
    ")\n",
    "\n",
    "# good defaults for a TPU v3-8\n",
    "batch_size = 128  # number of images per device\n",
    "num_workers = 8  # number of workers used for parallel processing\n",
    "total_bs = batch_size * jax.device_count()  # batch size across all devices; you can use a smaller size while testing\n",
    "save_frequency = 128  # number of batches gathered into each output file (180MB for f16 and 720MB for f8 per file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['XXX/shard-0000.tar',\n",
       " 'XXX/shard-0001.tar',\n",
       " 'XXX/shard-0002.tar',\n",
       " 'XXX/shard-0003.tar',\n",
       " 'XXX/shard-0004.tar',\n",
       " 'XXX/shard-0005.tar',\n",
       " 'XXX/shard-0006.tar',\n",
       " 'XXX/shard-0007.tar',\n",
       " 'XXX/shard-0008.tar']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Expand the brace pattern into an explicit list: better display for tqdm with known length.\n",
    "# The `isinstance` guard keeps this cell idempotent, so running it a second time\n",
    "# does not try to brace-expand the list produced by the first run.\n",
    "if isinstance(shards, str):\n",
    "    shards = list(braceexpand.braceexpand(shards))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75dba8e2",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1e8fb95",
   "metadata": {},
   "source": [
    "We load data using `webdataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ef5de9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the webdataset pipeline over the expanded shard list. Both the shard reader\n",
    "# and the decoder are given `wds.warn_and_continue` as their error handler.\n",
    "ds = (\n",
    "    wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
    "    .decode(\"rgb\", handler=wds.warn_and_continue)\n",
    "    .to_tuple(\"jpg\", \"txt\")  # assumes image is in `jpg` and caption in `txt`\n",
    "    .batched(total_bs)  # load in batch per worker (faster)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90981824",
   "metadata": {},
   "source": [
    "Note:\n",
    "* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
    "* you may need to resize images in your pipeline (with `map_dict` for example); we assume they are already 256x256.\n",
    "* you can also filter out some items using `select`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "129c377d",
   "metadata": {},
   "source": [
    "We can now inspect our data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cac98cb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "images, captions = next(iter(ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd268fbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "images.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5acfc4d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "captions[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c24693c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first image of the batch; `permute(2, 0, 1)` moves the last axis to the front.\n",
    "# NOTE(review): `.permute` is a tensor method -- confirm that batches taken directly from `ds`\n",
    "# are tensors rather than NumPy arrays, otherwise this cell needs adapting.\n",
    "T.ToPILImage()(images[0].permute(2, 0, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3059ffb1",
   "metadata": {},
   "source": [
    "Finally we create our dataloader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c227c551",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dataset already yields batches, hence `batch_size=None` for the loader itself.\n",
    "# Un-batching and re-batching after the loader avoids a partial batch at the end of each worker.\n",
    "# The worker count is taken from `num_workers` in the configuration cell.\n",
    "dl = (\n",
    "    wds.WebLoader(ds, batch_size=None, num_workers=num_workers)\n",
    "    .unbatched()\n",
    "    .batched(total_bs)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a354472b",
   "metadata": {},
   "source": [
    "## Image encoder\n",
    "\n",
    "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47a8b818",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
    "from flax.jax_utils import replicate\n",
    "\n",
    "# Load the checkpoint selected in the configuration cell, pinned to its commit id\n",
    "# so that every run encodes with the same weights.\n",
    "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
    "# Copy the parameters once per device: the pmapped encoder maps over the\n",
    "# leading axis of every argument it receives.\n",
    "vqgan_params = replicate(vqgan.params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62ad01c3",
   "metadata": {},
   "source": [
    "## Encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20357f74",
   "metadata": {},
   "source": [
    "Encoding is straightforward: `shard` splits each batch across the available devices and `pmap` runs the encoder on all of them in parallel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "322a4619",
   "metadata": {},
   "outputs": [],
   "source": [
    "from flax.training.common_utils import shard\n",
    "from functools import partial\n",
    "\n",
    "\n",
    "@partial(jax.pmap, axis_name=\"batch\")\n",
    "def p_encode(batch, params):\n",
    "    \"\"\"Return the indices that `vqgan.encode` produces for one per-device batch.\n",
    "\n",
    "    The function runs under `jax.pmap`, which by default maps over the leading\n",
    "    axis of every argument. `params` therefore needs a leading device axis too,\n",
    "    for instance the one added by `flax.jax_utils.replicate`.\n",
    "\n",
    "    Args:\n",
    "        batch: slice of the sharded image batch seen by one device.\n",
    "        params: slice of the VQGAN parameters seen by one device.\n",
    "\n",
    "    Returns:\n",
    "        The second element of the pair returned by `vqgan.encode`.\n",
    "    \"\"\"\n",
    "    encoder_outputs = vqgan.encode(batch, params=params)\n",
    "    _ignored, code_indices = encoder_outputs\n",
    "    return code_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff6c10d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def _save_encodings(captions, encodings, output_dir, n_file):\n",
    "    \"\"\"Write caption/encoding pairs to the parquet file numbered `n_file` in `output_dir`.\"\"\"\n",
    "    batch_df = pd.DataFrame.from_dict({\"caption\": captions, \"encoding\": encodings})\n",
    "    batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
    "\n",
    "\n",
    "def encode_dataset(dataloader, output_dir, save_frequency):\n",
    "    \"\"\"Encode every image of `dataloader` and save caption/encoding pairs as parquet files.\n",
    "\n",
    "    Args:\n",
    "        dataloader: iterable yielding `(images, captions)` batches. `images` must\n",
    "            provide `.numpy()` (presumably a torch tensor -- confirm against the loader).\n",
    "        output_dir: directory receiving the files `001.parquet`, `002.parquet`, ...;\n",
    "            created if missing. Accepts a `str` or a `pathlib.Path`.\n",
    "        save_frequency: number of batches read between two file writes.\n",
    "    \"\"\"\n",
    "    output_dir = Path(output_dir)\n",
    "    output_dir.mkdir(parents=True, exist_ok=True)\n",
    "    # `shard` splits the leading axis across the local devices, so each batch\n",
    "    # must contain a multiple of that count.\n",
    "    n_devices = jax.local_device_count()\n",
    "    all_captions = []\n",
    "    all_encoding = []\n",
    "    n_file = 1\n",
    "    for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
    "        images = images.numpy()\n",
    "        n = len(images) // n_devices * n_devices\n",
    "        if n != len(images):\n",
    "            # keep the largest prefix that splits evenly across the devices\n",
    "            print(f\"Different sizes {n} vs {len(images)}\")\n",
    "            images = images[:n]\n",
    "            captions = captions[:n]\n",
    "        if not len(captions):\n",
    "            print(\"No images/captions in batch...\")\n",
    "            continue\n",
    "        images = shard(images)\n",
    "        encoded = p_encode(images, vqgan_params)\n",
    "        # flatten all leading axes so that the result is a 2D array\n",
    "        # (assumes one row of tokens per image -- confirm against the encoder output)\n",
    "        encoded = encoded.reshape(-1, encoded.shape[-1])\n",
    "        all_captions.extend(captions)\n",
    "        all_encoding.extend(encoded.tolist())\n",
    "\n",
    "        # save files\n",
    "        if (idx + 1) % save_frequency == 0:\n",
    "            print(f\"Saving file {n_file}\")\n",
    "            _save_encodings(all_captions, all_encoding, output_dir, n_file)\n",
    "            all_captions = []\n",
    "            all_encoding = []\n",
    "            n_file += 1\n",
    "\n",
    "    # write whatever is left after the last full group of batches\n",
    "    if len(all_captions):\n",
    "        print(f\"Saving final file {n_file}\")\n",
    "        _save_encodings(all_captions, all_encoding, output_dir, n_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7704863d",
   "metadata": {},
   "outputs": [],
   "source": [
    "encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8953dd84",
   "metadata": {},
   "source": [
    "----"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
